#!/usr/bin/env python3
import argparse
import hashlib
import multiprocessing
import os

import jsonlines
import pefile
import pehash

from agtr import AGTR


def get_hashes(file_path):
    """Return a dict containing the md5 and peHash of a file.

    Arguments:
    file_path -- The path of the file to be hashed

    Returns a dict of the form {"md5": <hex str>, "pehash": <hex str or None>}.
    "pehash" is None when the file cannot be parsed as a PE, or when
    pehash.totalhash() returns None for it.

    Raises:
    ValueError -- if file_path is not an existing regular file
    """

    if not os.path.isfile(file_path):
        raise ValueError("No such file: {}".format(file_path))

    # Hash in fixed-size chunks inside a with-block so the handle is closed
    # promptly (this runs once per file in every pool worker) and large
    # samples are not read into a single bytes object just for md5.
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            md5.update(chunk)
    md5_hex = md5.hexdigest()

    # Attempt to parse PE headers. Any failure is deliberately treated as
    # "not a PE": malformed samples can raise many different exceptions and a
    # single bad file must not abort a whole batch.
    try:
        pe = pefile.PE(file_path, fast_load=True)
    except Exception:
        return {"md5": md5_hex, "pehash": None}

    # Attempt to compute peHash; totalhash() signals failure by returning None.
    pe_hash = pehash.totalhash(pe=pe)
    if pe_hash is None:
        return {"md5": md5_hex, "pehash": None}

    return {"md5": md5_hex, "pehash": pe_hash.hexdigest()}


def write_pehash_jsonl(args):
    """Write a .jsonl file with the md5 and peHash of each file in a directory.

    Arguments:
    args -- Parsed command line arguments. Attributes read:
        agtr_dir      -- directory walked recursively for input files
        agtr_jsonl    -- output path (a ``jsonl_file`` attribute, if present
                         and non-empty, takes precedence for compatibility)
        num_processes -- worker count; values below 1 (the default is -1)
                         mean "use os.cpu_count()"
        verbose       -- if true, print a running total after each batch

    Each output line is the dict returned by get_hashes() for one file.
    """

    output_path = getattr(args, "jsonl_file", None) or args.agtr_jsonl
    verbose = getattr(args, "verbose", False)
    num_processes = getattr(args, "num_processes", -1)
    if num_processes is None or num_processes < 1:
        num_processes = os.cpu_count()

    # Collect regular files only: get_hashes() raises ValueError for anything
    # else (e.g. broken symlinks, FIFOs), which would abort the whole pool.map.
    file_paths = []
    for root, _dirs, file_names in os.walk(args.agtr_dir):
        for file_name in file_names:
            path = os.path.join(root, file_name)
            if os.path.isfile(path):
                file_paths.append(path)

    # Process in fixed-size batches so results are written incrementally and
    # progress can be reported. Stepping by batch_size yields no batches for
    # an empty directory instead of dividing by zero.
    batch_size = 1000
    batches = [file_paths[start:start + batch_size]
               for start in range(0, len(file_paths), batch_size)]

    # Compute hashes for each batch of files
    num_processed = 0
    with multiprocessing.Pool(num_processes) as pool:
        with jsonlines.open(output_path, mode="w") as writer:
            for batch in batches:
                batch_hashes = pool.map(get_hashes, batch)
                writer.write_all(batch_hashes)
                num_processed += len(batch_hashes)
                if verbose:
                    print("Processed {} total files".format(num_processed))


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Print the precision, recall, accuracy and error rate "
                    "bounds that AGTR computes for a set of predicted labels.")
    parser.add_argument("pred_jsonl", help="The path to the .jsonl file " +
                        "containing the predicted labels. Format: " +
                        "{\"md5\": md5_val, \"family\": family_val}")
    parser.add_argument("--agtr-jsonl", help="The path to a .jsonl file " +
                        "containing the AGTR cluster labels. Format: " +
                        "{\"md5\": md5_val, \"pehash\": pehash_val}")
    parser.add_argument("--agtr-dir", help="A directory containing the " +
                        "files whose labels were predicted.")
    parser.add_argument("--num-processes", default=-1, type=int,
                        help="The maximum number of processes.")
    parser.add_argument("--verbose", default=False, action="store_true",
                        help="Print progress while hashing files in "
                             "--agtr-dir.")
    args = parser.parse_args()

    # Validate command line arguments; parser.error() prints usage and exits
    # with status 2, the conventional CLI behavior for bad arguments.
    if not os.path.isfile(args.pred_jsonl):
        parser.error("Invalid value of pred_jsonl")
    if not args.agtr_jsonl and not args.agtr_dir:
        parser.error("Must provide either --agtr-jsonl or --agtr-dir")
    if args.agtr_jsonl and not os.path.isfile(args.agtr_jsonl):
        parser.error("Invalid value of --agtr-jsonl")
    if args.agtr_dir and not os.path.isdir(args.agtr_dir):
        parser.error("Invalid value of --agtr-dir")

    # Write pehash .jsonl file (in the current directory) if none provided
    if args.agtr_dir and not args.agtr_jsonl:
        args.agtr_jsonl = "pehash.jsonl"
        write_pehash_jsonl(args)

    # Construct AGTR. The format tuples name the keys read from each .jsonl
    # file. NOTE(review): the meaning of the positional 0.01 is defined by
    # AGTR's constructor — confirm before changing it.
    agtr = AGTR(args.pred_jsonl, args.agtr_jsonl, 0.01,
                pred_format=("md5", "family"), agtr_format=("md5", "pehash"))

    # Print metric bounds
    print("Precision lower bound: {:.3f}".format(agtr.get_precision_bound()))
    print("Recall upper bound: {:.3f}".format(agtr.get_recall_bound()))
    print("Accuracy upper bound: {:.3f}".format(agtr.get_accuracy_bound()))
    print("Error rate lower bound: {:.3f}".format(agtr.get_error_rate_bound()))
